{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 228,
   "id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T16:59:09.577221Z",
     "iopub.status.busy": "2022-06-14T16:59:09.576545Z",
     "iopub.status.idle": "2022-06-14T16:59:09.593060Z",
     "shell.execute_reply": "2022-06-14T16:59:09.592160Z",
     "shell.execute_reply.started": "2022-06-14T16:59:09.577162Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "%pylab inline\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "id": "70d2263d-916c-431e-aca7-3e5bcd867f9b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T16:59:31.653844Z",
     "iopub.status.busy": "2022-06-14T16:59:31.653389Z",
     "iopub.status.idle": "2022-06-14T16:59:31.661109Z",
     "shell.execute_reply": "2022-06-14T16:59:31.659931Z",
     "shell.execute_reply.started": "2022-06-14T16:59:31.653805Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import nibabel as nib\n",
    "from nibabel.viewers import OrthoSlicer3D\n",
    "\n",
    "train_path = glob.glob('./讯飞初赛/Train/*/*')\n",
    "test_path = glob.glob('./讯飞初赛/Test/*')\n",
    "\n",
    "np.random.shuffle(train_path)\n",
    "np.random.shuffle(test_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "id": "584f2188-cf45-4d72-abee-4dae0d362926",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:24.573222Z",
     "iopub.status.busy": "2022-06-14T17:05:24.572609Z",
     "iopub.status.idle": "2022-06-14T17:05:24.581119Z",
     "shell.execute_reply": "2022-06-14T17:05:24.580438Z",
     "shell.execute_reply.started": "2022-06-14T17:05:24.573166Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "DATA_CACHE = {}\n",
    "class XunFeiDataset(Dataset):\n",
    "    def __init__(self, img_path, transform=None):\n",
    "        self.img_path = img_path\n",
    "        if transform is not None:\n",
    "            self.transform = transform\n",
    "        else:\n",
    "            self.transform = None\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        if self.img_path[index] in DATA_CACHE:\n",
    "            img = DATA_CACHE[self.img_path[index]]\n",
    "        else:\n",
    "            img = nib.load(self.img_path[index]) \n",
    "            img = img.dataobj[:,:,:, 0]\n",
    "            DATA_CACHE[self.img_path[index]] = img\n",
    "            \n",
    "        idx = np.random.choice(range(img.shape[0]), 130)\n",
    "        # idx.sort()\n",
    "        img = img[idx, :, :]\n",
    "        img = img.astype(np.float32)\n",
    "        \n",
    "        img /= 255.0\n",
    "        img -= 1\n",
    "        img = img.transpose([1,2,0])\n",
    "        \n",
    "        # print(img.shape)\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(image = img)['image']\n",
    "        \n",
    "        img = img.transpose([2,0,1])\n",
    "        return img,torch.from_numpy(np.array(int('AD' in self.img_path[index])))\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:25.574953Z",
     "iopub.status.busy": "2022-06-14T17:05:25.574347Z",
     "iopub.status.idle": "2022-06-14T17:05:25.582953Z",
     "shell.execute_reply": "2022-06-14T17:05:25.582313Z",
     "shell.execute_reply.started": "2022-06-14T17:05:25.574903Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiNet(nn.Module):\n",
    "    \"\"\"ResNet-18 with a 130-channel input convolution and a 2-unit output layer.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # True is passed positionally to resnet18 (presumably the `pretrained` flag\n",
    "        # of the installed torchvision version - confirm).\n",
    "        backbone = models.resnet18(True)\n",
    "        # Replace the first convolution so the network accepts 130 input channels.\n",
    "        backbone.conv1 = nn.Conv2d(130, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
    "        backbone.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        # Replace the final layer with a 512 -> 2 linear layer.\n",
    "        backbone.fc = nn.Linear(512, 2)\n",
    "        self.resnet = backbone\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Return the wrapped network's output for the batch ``img``.\"\"\"\n",
    "        return self.resnet(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "id": "47539bef-14a7-4881-85e9-0882ff295e6a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:25.584093Z",
     "iopub.status.busy": "2022-06-14T17:05:25.583904Z",
     "iopub.status.idle": "2022-06-14T17:05:25.686183Z",
     "shell.execute_reply": "2022-06-14T17:05:25.685643Z",
     "shell.execute_reply.started": "2022-06-14T17:05:25.584076Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer):\n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        # compute output\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # compute gradient and do SGD step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 20 == 0:\n",
    "            print(loss.item())\n",
    "            \n",
    "        train_loss += loss.item()\n",
    "    \n",
    "    return train_loss/len(train_loader)\n",
    "            \n",
    "def validate(val_loader, model, criterion):\n",
    "    model.eval()\n",
    "    \n",
    "    val_acc = 0.0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            loss = criterion(output, target)\n",
    "            \n",
    "            val_acc += (output.argmax(1) == target).sum().item()\n",
    "            \n",
    "    return val_acc / len(val_loader.dataset)\n",
    "\n",
    "def predict(test_loader, model, criterion):\n",
    "    model.eval()\n",
    "    val_acc = 0.0\n",
    "    \n",
    "    test_pred = []\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in enumerate(test_loader):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            test_pred.append(output.data.cpu().numpy())\n",
    "            \n",
    "    return np.vstack(test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 252,
   "id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:25.956212Z",
     "iopub.status.busy": "2022-06-14T17:05:25.955660Z",
     "iopub.status.idle": "2022-06-14T17:05:25.965816Z",
     "shell.execute_reply": "2022-06-14T17:05:25.965158Z",
     "shell.execute_reply.started": "2022-06-14T17:05:25.956163Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "\n",
    "# NOTE(review): this cell used to slice an undefined name `nni_path`, which raises a\n",
    "# NameError on a fresh kernel. `train_path` (the shuffled training file list) is\n",
    "# presumably what was intended - confirm.\n",
    "VAL_SIZE = 10  # the last VAL_SIZE entries of train_path are held out for validation\n",
    "\n",
    "# Training split: everything except the held-out tail, with random augmentation.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_path[:-VAL_SIZE],\n",
    "            A.Compose([\n",
    "            A.RandomRotate90(),\n",
    "            A.RandomCrop(128, 128),\n",
    "            A.HorizontalFlip(p=0.5),\n",
    "            A.RandomContrast(p=0.5),\n",
    "            A.RandomBrightnessContrast(p=0.5),\n",
    "        ])\n",
    "    ), batch_size=2, shuffle=True, num_workers=1, pin_memory=False\n",
    ")\n",
    "\n",
    "# Validation split: the held-out tail, random crop only.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_path[-VAL_SIZE:],\n",
    "            A.Compose([\n",
    "            A.RandomCrop(128, 128),\n",
    "        ])\n",
    "    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n",
    "\n",
    "# Test files: random crop only; shuffle=False keeps outputs in test_path order.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(test_path,\n",
    "            A.Compose([\n",
    "            A.RandomCrop(128, 128),\n",
    "        ])\n",
    "    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:26.619325Z",
     "iopub.status.busy": "2022-06-14T17:05:26.618755Z",
     "iopub.status.idle": "2022-06-14T17:05:26.865090Z",
     "shell.execute_reply": "2022-06-14T17:05:26.864481Z",
     "shell.execute_reply.started": "2022-06-14T17:05:26.619276Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Network on the GPU, cross-entropy loss, and AdamW with learning rate 0.001.\n",
    "model = XunFeiNet().to('cuda')\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:27.573578Z",
     "iopub.status.busy": "2022-06-14T17:05:27.572985Z",
     "iopub.status.idle": "2022-06-14T17:05:43.841782Z",
     "shell.execute_reply": "2022-06-14T17:05:43.841173Z",
     "shell.execute_reply.started": "2022-06-14T17:05:27.573529Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7506477236747742\n",
      "1.0360616509569809 0.75 0.0\n",
      "4.075284957885742\n",
      "1.8050282004289329 0.65 0.5\n",
      "0.8950101733207703\n",
      "1.0674029730260373 0.4 0.5\n"
     ]
    }
   ],
   "source": [
    "# Three passes over the training data. After each pass, accuracy is measured on the\n",
    "# validation loader and then on the training loader, and one line is printed:\n",
    "# mean training loss, training accuracy, validation accuracy.\n",
    "for epoch in range(3):\n",
    "    train_loss = train(train_loader, model, criterion, optimizer)\n",
    "    val_acc = validate(val_loader, model, criterion)\n",
    "    train_acc = validate(train_loader, model, criterion)\n",
    "\n",
    "    print(train_loss, train_acc, val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "id": "6787d12a-b923-4777-abf5-4d90904551d1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:05:46.718609Z",
     "iopub.status.busy": "2022-06-14T17:05:46.717995Z",
     "iopub.status.idle": "2022-06-14T17:06:36.949217Z",
     "shell.execute_reply": "2022-06-14T17:06:36.947602Z",
     "shell.execute_reply.started": "2022-06-14T17:05:46.718554Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sum the raw model outputs of 10 prediction passes over the test loader.\n",
    "# The first pass creates the accumulator; the other nine are added in place.\n",
    "pred = predict(test_loader, model, criterion)\n",
    "for _ in range(9):\n",
    "    pred += predict(test_loader, model, criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:06:36.950578Z",
     "iopub.status.busy": "2022-06-14T17:06:36.950395Z",
     "iopub.status.idle": "2022-06-14T17:06:36.954739Z",
     "shell.execute_reply": "2022-06-14T17:06:36.954325Z",
     "shell.execute_reply.started": "2022-06-14T17:06:36.950560Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# uuid: the file name (text after the last '/') without its final 4 characters,\n",
    "# parsed as an integer. label: index of the largest summed output per row.\n",
    "test_ids = [int(path.split('/')[-1][:-4]) for path in test_path]\n",
    "submit = pd.DataFrame({'uuid': test_ids, 'label': pred.argmax(1)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:06:36.955830Z",
     "iopub.status.busy": "2022-06-14T17:06:36.955667Z",
     "iopub.status.idle": "2022-06-14T17:06:37.083481Z",
     "shell.execute_reply": "2022-06-14T17:06:37.082439Z",
     "shell.execute_reply.started": "2022-06-14T17:06:36.955814Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Replace class indices by their label names, order the rows by uuid and write the\n",
    "# result to CSV without the index column.\n",
    "label_names = {1: 'AD', 0: 'NC'}\n",
    "submit = submit.assign(label=submit['label'].map(label_names)).sort_values(by='uuid')\n",
    "submit.to_csv('submit2.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de0cae92-7686-4cea-93a9-83ba4356531e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
